# -*- encoding: utf-8 -*-
"""
利用历史数据测试交易策略

回测步骤：
1. 读取 OHLC 数据；
2. 对 OHLC 进行指标运算；
3. 策略根据指标向量决定买卖；
4. 发给模拟的”交易所“进行交易；
5. 最后，统计结果。

回测框架：Pandas + Numpy

代码划分：
交易所类（ExchangeAPI）：负责账户的资金和仓位，以及进行模拟交易
策略类（StrategyAPI）：负责根据市场信息生成指标，根据指标决定买卖
回测类框架（BackTest）：包含一个策略类和一个交易所类，负责迭代地对每个数据点调用策略执行
     - 输入：
        - OHLCV数据，相当于历史交易数据
        - 初始资金
        - 手续费率
        - 交易所类
        - 策略类
    - 输出
        - 最后剩余市值
"""
from os import path
import pandas as pd
import numpy as np
import abc
from typing import Callable


def assert_msg(condition, msg):
    if not condition:
        raise Exception(msg)


def read_file(filename):
    # 获得文件绝对路径
    filepath = path.join(path.dirname(__file__), 'datas', filename)

    # 判定文件是否存在
    assert_msg(path.exists(filepath), " 文件不存在 ")

    # 读取 CSV 文件并返回
    return pd.read_csv(filepath,
                       index_col=0,
                       parse_dates=True,
                       infer_datetime_format=True)


BTCUSD = read_file('BTCUSD_GEMINI.csv')
assert_msg(BTCUSD.__len__() > 0, '读取失败')
print(BTCUSD.head())
print(BTCUSD.columns)


def SMA(values, n):
    """
    返回简单滑动平均
    """
    print(pd.Series(values).rolling(n).mean())
    return pd.Series(values).rolling(n).mean()


def crossover(series1, series2) -> bool:
    """
    检查两个序列是否在结尾交叉
    :param series1:  序列 1
    :param series2:  序列 2
    :return:         如果交叉返回 True，反之 False
    """
    return series1[-2] < series2[-2] and series1[-1] > series2[-1]


def next(self, tick):
    # 如果此时快线刚好越过慢线，买入全部
    if crossover(self.sma1[:tick], self.sma2[:tick]):
        self.buy()

    # 如果是慢线刚好越过快线，卖出全部
    elif crossover(self.sma2[:tick], self.sma1[:tick]):
        self.sell()

    # 否则，这个时刻不执行任何操作。
    else:
        pass


class ExchangeAPI:
    def __init__(self, data, cash, commission):
        assert_msg(0 < cash, " 初始现金数量大于 0，输入的现金数量：{}".format(cash))
        assert_msg(0 <= commission <= 0.05, " 合理的手续费率一般不会超过 5%，输入的费率：{}".format(commission))
        self._inital_cash = cash
        self._data = data
        self._commission = commission
        self._position = 0
        self._cash = cash
        self._i = 0

    @property
    def cash(self):
        """
        :return: 返回当前账户现金数量
        """
        return self._cash

    @property
    def position(self):
        """
        :return: 返回当前账户仓位
        """
        return self._position

    @property
    def initial_cash(self):
        """
        :return: 返回初始现金数量
        """
        return self._inital_cash

    @property
    def market_value(self):
        """
        :return: 返回当前市值
        """
        return self._cash + self._position * self.current_price

    @property
    def current_price(self):
        """
        :return: 返回当前市场价格
        """
        return self._data.Close[self._i]

    def buy(self):
        """
        用当前账户剩余资金，按照市场价格全部买入
        """
        self._position = float(self._cash / (self.current_price * (1 + self._commission)))
        self._cash = 0.0

    def sell(self):
        """
        卖出当前账户剩余持仓
        """
        self._cash += float(self._position * self.current_price * (1 - self._commission))
        self._position = 0.0

    def next(self, tick):
        self._i = tick


class Strategy(metaclass=abc.ABCMeta):
    """
    抽象策略类，用于定义交易策略。

    如果要定义自己的策略类，需要继承这个基类，并实现两个抽象方法：
    Strategy.init
    Strategy.next
    """

    def __init__(self, broker, data):
        """
        构造策略对象。

        @params broker:  ExchangeAPI    交易 API 接口，用于模拟交易
        @params data:    list           行情数据数据
        """
        self._indicators = []
        self._broker = broker  # type: _Broker
        self._data = data  # type: _Data
        self._tick = 0

    def I(self, func: Callable, *args) -> np.ndarray:
        """
        计算买卖指标向量。买卖指标向量是一个数组，长度和历史数据对应；
        用于判定这个时间点上需要进行 " 买 " 还是 " 卖 "。

        例如计算滑动平均：
        def init():
            self.sma = self.I(utils.SMA, self.data.Close, N)
        """
        value = func(*args)
        value = np.asarray(value)
        assert_msg(value.shape[-1] == len(self._data.Close), '指示器长度必须和 data 长度相同')

        self._indicators.append(value)
        return value

    @property
    def tick(self):
        return self._tick

    @abc.abstractmethod
    def init(self):
        """
        初始化策略。在策略回测 / 执行过程中调用一次，用于初始化策略内部状态。
        这里也可以预计算策略的辅助参数。比如根据历史行情数据：
        计算买卖的指示器向量；
        训练模型 / 初始化模型参数
        """
        pass

    @abc.abstractmethod
    def next(self, tick):
        """
        步进函数，执行第 tick 步的策略。tick 代表当前的 " 时间 "。比如 data[tick] 用于访问当前的市场价格。
        """
        pass

    def buy(self):
        self._broker.buy()

    def sell(self):
        self._broker.sell()

    @property
    def data(self):
        return self._data


class SmaCross(Strategy):
    # 小窗口 SMA 的窗口大小，用于计算 SMA 快线
    fast = 10

    # 大窗口 SMA 的窗口大小，用于计算 SMA 慢线
    slow = 20

    def init(self):
        # 计算历史上每个时刻的快线和慢线
        self.sma1 = self.I(SMA, self.data.Close, self.fast)
        self.sma2 = self.I(SMA, self.data.Close, self.slow)

    def next(self, tick):
        # 如果此时快线刚好越过慢线，买入全部
        if crossover(self.sma1[:tick], self.sma2[:tick]):
            self.buy()

        # 如果是慢线刚好越过快线，卖出全部
        elif crossover(self.sma2[:tick], self.sma1[:tick]):
            self.sell()

        # 否则，这个时刻不执行任何操作。
        else:
            pass


class Backtest:
    """
    Backtest 回测类，用于读取历史行情数据、执行策略、模拟交易并估计
    收益。

    初始化的时候调用 Backtest.run 来时回测

    instance, or `backtesting.backtesting.Backtest.optimize` to
    optimize it.
    """

    def __init__(self,
                 data: pd.DataFrame,
                 strategy_type: type(Strategy),
                 broker_type: type(ExchangeAPI),
                 cash: float = 10000,
                 commission: float = .0):
        """
        构造回测对象。需要的参数包括：历史数据，策略对象，初始资金数量，手续费率等。
        初始化过程包括检测输入类型，填充数据空值等。

        参数：
        :param data:            pd.DataFrame        pandas Dataframe 格式的历史 OHLCV 数据
        :param broker_type:     type(ExchangeAPI)   交易所 API 类型，负责执行买卖操作以及账户状态的维护
        :param strategy_type:   type(Strategy)      策略类型
        :param cash:            float               初始资金数量
        :param commission:       float               每次交易手续费率。如 2% 的手续费此处为 0.02
        """

        assert_msg(issubclass(strategy_type, Strategy), 'strategy_type 不是一个 Strategy 类型')
        assert_msg(issubclass(broker_type, ExchangeAPI), 'broker_type 不是一个 ExchangeAPI 类型')
        assert_msg(isinstance(commission, float), 'commission 不是浮点数值类型')

        data = data.copy(False)

        # 如果没有 Volumn 列，填充 NaN
        if 'Volume' not in data:
            data['Volume'] = np.nan

        # 验证 OHLC 数据格式
        assert_msg(len(data.columns & {'Open', 'High', 'Low', 'Close', 'Volume'}) == 5,
                   (" 输入的`data`格式不正确，至少需要包含这些列："
                    "'Open', 'High', 'Low', 'Close'"))

        # 检查缺失值
        assert_msg(not data[['Open', 'High', 'Low', 'Close']].max().isnull().any(),
                   ('部分 OHLC 包含缺失值，请去掉那些行或者通过差值填充. '))

        # 如果行情数据没有按照时间排序，重新排序一下
        if not data.index.is_monotonic_increasing:
            data = data.sort_index()

        # 利用数据，初始化交易所对象和策略对象。
        self._data = data  # type: pd.DataFrame
        self._broker = broker_type(data, cash, commission)
        self._strategy = strategy_type(self._broker, self._data)
        self._results = None

    def run(self):
        """
        运行回测，迭代历史数据，执行模拟交易并返回回测结果。
        Run the backtest. Returns `pd.Series` with results and statistics.

        Keyword arguments are interpreted as strategy parameters.
        """
        strategy = self._strategy
        broker = self._broker

        # 策略初始化
        strategy.init()

        # 设定回测开始和结束位置
        start = 100
        end = len(self._data)

        # 回测主循环，更新市场状态，然后执行策略
        for i in range(start, end):
            # 注意要先把市场状态移动到第 i 时刻，然后再执行策略。
            broker.next(i)
            strategy.next(i)

        # 完成策略执行之后，计算结果并返回
        self._results = self._compute_result(broker)
        return self._results

    def _compute_result(self, broker):
        s = pd.Series()
        s['初始市值'] = broker.initial_cash
        s['结束市值'] = broker.market_value
        s['收益'] = broker.market_value - broker.initial_cash
        return s


def main():
    BTCUSD = read_file('BTCUSD_GEMINI.csv')
    ret = Backtest(BTCUSD, SmaCross, ExchangeAPI, 10000.0, 0.02).run()
    print(ret)


if __name__ == '__main__':
    main()
